from dataclasses import dataclass
import torch

#from Data_generator import get_random_problems, get_random_eval_problems


@dataclass
class Reset_State:
    """Observation handed back by ``PFSPEnv.reset``: the episode's problem instances."""

    # Problem tensor, presumably (batch, job, 3) with the last axis holding
    # (processing_time, due_date, weight) as ``PFSPEnv.compute_TWT`` expects;
    # None when no problems have been loaded.
    problems: torch.Tensor


@dataclass
class Step_State:
    """Per-step observation exchanged between ``PFSPEnv`` and its caller.

    ``PFSPEnv._create_step_state`` builds one instance per episode and
    ``PFSPEnv._update_step_state`` mutates it in place after every step.
    """

    # (batch, pomo) index grids for advanced indexing into per-rollout tensors.
    BATCH_IDX: torch.Tensor
    POMO_IDX: torch.Tensor
    # (batch, pomo): job index chosen at the latest step; None before the first step.
    current_node: torch.Tensor = None
    # (batch, pomo, job): 0 for still-selectable jobs, -inf for already-chosen ones.
    ninf_mask: torch.Tensor = None
    # NOTE(review): never assigned by the env in this file; meaning/shape unknown.
    machine_time: torch.Tensor = None


class PFSPEnv:
    """Job-sequencing environment rewarded by negative total weighted tardiness.

    For every (batch, pomo) pair the caller picks one job per ``step`` until all
    ``job_cnt`` jobs are chosen. The reward is then ``-compute_TWT(...)``, so
    maximizing reward minimizes total weighted tardiness.

    Problem tensors hold, per job, (processing_time, due_date, weight) along a
    last axis of size 3 (see ``compute_TWT``).
    """

    def __init__(self, **env_params):
        """Create the environment.

        Args:
            **env_params: must contain ``pomo_size`` (size of the pomo axis of
                all per-step tensors) and ``job_cnt`` (jobs per instance). The
                whole dict is kept on ``self.env_params``.
        """
        # Const @INIT
        ####################################
        self.env_params = env_params
        self.pomo_size = env_params['pomo_size']
        self.n_jobs = env_params['job_cnt']

        # Const @Load_Problem
        ####################################
        self.batch_size = None
        self.BATCH_IDX = None
        self.POMO_IDX = None
        # IDX.shape: (batch, pomo)
        self.problems = None
        # shape: (batch, job, 3)

        # Dynamic
        ####################################
        self.selected_count = None
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = None
        # shape: (batch, pomo, 0~job)

        # STEP-State
        ####################################
        self.step_state = None

    def _set_batch_indices(self, batch_size, device=None):
        """Record ``batch_size`` and build the (batch, pomo) index grids on ``device``."""
        self.batch_size = batch_size
        self.BATCH_IDX = torch.arange(batch_size, device=device)[:, None].expand(batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size, device=device)[None, :].expand(batch_size, self.pomo_size)

    def _problem_device(self):
        """Device of the loaded problems, or None (torch's default) if none are loaded."""
        return None if self.problems is None else self.problems.device

    def load_problems(self, batch_size, proj_type):
        """Prepare the index grids for a batch of ``batch_size`` instances.

        NOTE(review): random problem generation is disabled. The
        ``Data_generator`` import is commented out, and the code referenced
        ``self.n_mc`` / ``self.mode``, which this class never sets. So this
        method does NOT set ``self.problems`` and ``proj_type`` is unused. Use
        ``load_problems_manual`` (or assign ``self.problems``) to supply
        instances.
        """
        self._set_batch_indices(batch_size)

        # TODO: restore generation once a generator matching this env exists.
        # if proj_type== 'train':
        #     self.problems = get_random_problems(batch_size, self.n_jobs, self.n_mc, self.mode)
        # else:
        #     self.problems = get_random_eval_problems(batch_size, self.n_jobs, self.n_mc, self.mode, seed=1235)
        # shape: (batch, job, 3)

    def load_problems_manual(self, problems):
        """Use ``problems`` (batch first, e.g. (batch, job, 3)) as the episode's instances.

        The index grids are created on ``problems.device`` so that later
        indexing and concatenation stay on one device.
        """
        self._set_batch_indices(problems.size(0), device=problems.device)
        self.problems = problems

    def reset(self):
        """Start a new episode.

        Returns:
            (Reset_State(problems=self.problems), None, False).
        """
        self.selected_count = 0
        self.current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = torch.empty(
            (self.batch_size, self.pomo_size, 0), dtype=torch.long, device=self._problem_device()
        )
        # shape: (batch, pomo, 0~job)

        self._create_step_state()

        reward = None
        done = False
        return Reset_State(problems=self.problems), reward, done

    def _create_step_state(self):
        """Build a fresh Step_State whose mask marks every job as selectable (0)."""
        self.step_state = Step_State(BATCH_IDX=self.BATCH_IDX, POMO_IDX=self.POMO_IDX)
        self.step_state.ninf_mask = torch.zeros(
            (self.batch_size, self.pomo_size, self.n_jobs), device=self._problem_device()
        )
        # shape: (batch, pomo, job)

    def pre_step(self):
        """Return (step_state, None, False) before any job has been chosen."""
        reward = None
        done = False
        return self.step_state, reward, done

    def step(self, node_idx):
        """Append one chosen job per rollout.

        Args:
            node_idx: (batch, pomo) long tensor of job indices. ``step`` does not
                reject already-chosen jobs; presumably the caller applies
                ``step_state.ninf_mask`` when choosing.

        Returns:
            (step_state, reward, done). ``done`` is True once ``job_cnt`` steps
            have been taken. ``reward`` is then the (batch, pomo) tensor of
            negative total weighted tardiness; otherwise it is None.
        """
        self.selected_count += 1

        self.current_node = node_idx
        # shape: (batch, pomo)
        self.selected_node_list = torch.cat((self.selected_node_list, self.current_node[:, :, None]), dim=2)
        # shape: (batch, pomo, 0~job)

        self._update_step_state()

        done = (self.selected_count == self.n_jobs)
        if done:
            # Minus sign: the agent maximizes reward, i.e. minimizes TWT.
            reward = -self.compute_TWT(self.problems, self.selected_node_list)
            # shape: (batch, pomo)
        else:
            reward = None
        return self.step_state, reward, done

    def _update_step_state(self):
        """Publish the latest choice and mask the chosen jobs with -inf."""
        self.step_state.current_node = self.current_node
        # shape: (batch, pomo)
        self.step_state.ninf_mask[self.BATCH_IDX, self.POMO_IDX, self.current_node] = float('-inf')
        # shape: (batch, pomo, job)

    def compute_TWT(self, problems, selected):
        """Total weighted tardiness of each job sequence.

        Args:
            problems: batch-first tensor reshapeable to (batch, job, 3) holding
                (processing_time, due_date, weight) per job.
            selected: (batch, pomo, job) long tensor. ``selected[b, p]`` is the
                order in which rollout ``p`` processes the jobs of instance ``b``.

        Returns:
            (batch, pomo) tensor of ``sum(weight * max(0, completion - due_date))``.
            A job's completion time is the running sum of processing times in
            the given order.

        Raises:
            AssertionError: if the last problem axis is not 3 or ``selected``
                has the wrong shape.
        """
        # Broadcast each instance over the pomo axis as a view (no pomo_size
        # copies). The batch size comes from ``problems`` itself, not self.batch_size.
        jobs = problems.reshape(problems.size(0), 1, self.n_jobs, -1).expand(-1, self.pomo_size, -1, -1)
        B, P, N, D = jobs.shape
        assert D == 3, "Jobs tensor should have last dimension size 3 (processing_time, due_date, weight)"
        assert selected.shape == (B, P, N), "Selected tensor should have shape [B, P, N]"

        # Reorder every rollout's jobs into its processing order: [B, P, N, 3].
        order_index = selected.unsqueeze(-1).expand(-1, -1, -1, D)
        jobs_in_order = torch.gather(jobs, dim=2, index=order_index)
        processing_times, due_dates, weights = jobs_in_order.unbind(dim=-1)  # each [B, P, N]

        completion_times = processing_times.cumsum(dim=2)
        tardiness = torch.relu(completion_times - due_dates)
        return (tardiness * weights).sum(dim=2)  # [B, P]

